"""Display class to aggregate and print the results of many measurements."""
import collections
import itertools as it
from typing import cast, List, Optional, Tuple

import numpy as np

from torch.utils._benchmark.utils import common

# Names exported by a star-import of this module.
__all__ = ["Compare"]

# ANSI SGR escape sequences used by `_Row.color_segment` to highlight cells.
BEST = "\033[92m"  # bright green foreground
GOOD = "\033[34m"  # blue foreground
BAD = "\033[2m\033[91m"  # faint (dim) followed by bright red foreground
VERY_BAD = "\033[31m"  # red foreground
BOLD = "\033[1m"  # bold / increased intensity
TERMINATE = "\033[0m"  # reset all attributes


# Classes to separate internal bookkeeping from what is rendered.
class _Column(object):
    def __init__(
        self,
        grouped_results: List[Tuple[common.Measurement, ...]],
        time_scale: float,
        time_unit: str,
        trim_significant_figures: bool,
        highlight_warnings: bool,
    ):
        self._grouped_results = grouped_results
        self._flat_results = list(it.chain(*grouped_results))
        self._time_scale = time_scale
        self._time_unit = time_unit
        self._trim_significant_figures = trim_significant_figures
        self._highlight_warnings = highlight_warnings and any(r.has_warnings for r in self._flat_results)
        leading_digits = [
            int(np.ceil(np.log10(r.median / self._time_scale)))
            for r in self._flat_results
        ]
        unit_digits = max(leading_digits)
        decimal_digits = min(
            max(m.significant_figures - digits, 0)
            for digits, m in zip(leading_digits, self._flat_results)
        ) if self._trim_significant_figures else 1
        length = unit_digits + decimal_digits + (1 if decimal_digits else 0)
        self._template = f"{{:>{length}.{decimal_digits}f}}{{:>{7 if self._highlight_warnings else 0}}}"

    def get_results_for(self, group):
        return self._grouped_results[group]

    def num_to_str(self, value: float, estimated_sigfigs: int, spread: Optional[float]):
        if self._trim_significant_figures:
            value = common.trim_sigfig(value, estimated_sigfigs)
        return self._template.format(
            value,
            f" (! {spread:.0f}%)" if self._highlight_warnings and spread is not None else "")


class _Row(object):
    def __init__(self, results, row_group, render_env, env_str_len,
                 row_name_str_len, time_scale, colorize, num_threads=None):
        super(_Row, self).__init__()
        self._results = results
        self._row_group = row_group
        self._render_env = render_env
        self._env_str_len = env_str_len
        self._row_name_str_len = row_name_str_len
        self._time_scale = time_scale
        self._colorize = colorize
        self._columns = None
        self._num_threads = num_threads

    def register_columns(self, columns: Tuple[_Column, ...]):
        self._columns = columns

    def as_column_strings(self):
        env = f"({self._results[0].env})" if self._render_env else ""
        env = env.ljust(self._env_str_len + 4)
        output = ["  " + env + self._results[0].as_row_name]
        for m, col in zip(self._results, self._columns or ()):
            output.append(col.num_to_str(
                m.median / self._time_scale,
                m.significant_figures,
                m.median / m._iqr if m.has_warnings else None
            ))
        return output

    @staticmethod
    def color_segment(segment, value, group_values):
        best_value = min(group_values)
        if value <= best_value * 1.01 or value <= best_value + 100e-9:
            return BEST + BOLD + segment + TERMINATE * 2
        if value <= best_value * 1.1:
            return GOOD + BOLD + segment + TERMINATE * 2
        if value >= best_value * 5:
            return VERY_BAD + BOLD + segment + TERMINATE * 2
        if value >= best_value * 2:
            return BAD + segment + TERMINATE * 2

        return segment

    def row_separator(self, overall_width):
        return (
            [f"{self._num_threads} threads: ".ljust(overall_width, "-")]
            if self._num_threads is not None else []
        )

    def finalize_column_strings(self, column_strings, col_widths):
        row_contents = [column_strings[0].ljust(col_widths[0])]
        for col_str, width, result, column in zip(column_strings[1:], col_widths[1:], self._results, self._columns or ()):
            col_str = col_str.center(width)
            if self._colorize:
                group_medians = [r.median for r in column.get_results_for(self._row_group)]
                col_str = self.color_segment(col_str, result.median, group_medians)
            row_contents.append(col_str)
        return row_contents


class Table(object):
    def __init__(self, results: List[common.Measurement], colorize: bool,
                 trim_significant_figures: bool, highlight_warnings: bool):
        assert len(set(r.label for r in results)) == 1

        self.results = results
        self._colorize = colorize
        self._trim_significant_figures = trim_significant_figures
        self._highlight_warnings = highlight_warnings
        self.label = results[0].label
        self.time_unit, self.time_scale = common.select_unit(
            min(r.median for r in results)
        )

        self.row_keys = common.ordered_unique([self.row_fn(i) for i in results])
        self.row_keys.sort(key=lambda args: args[:2])  # preserve stmt order
        self.column_keys = common.ordered_unique([self.col_fn(i) for i in results])
        self.rows, self.columns = self.populate_rows_and_columns()

    @staticmethod
    def row_fn(m: common.Measurement):
        return m.num_threads, m.env, m.as_row_name

    @staticmethod
    def col_fn(m: common.Measurement):
        return m.description

    def populate_rows_and_columns(self):
        rows, columns = [], []

        ordered_results: List[List[Optional[common.Measurement]]] = [[None for _ in self.column_keys] for _ in self.row_keys]
        row_position = {key: i for i, key in enumerate(self.row_keys)}
        col_position = {key: i for i, key in enumerate(self.column_keys)}
        for r in self.results:
            i = row_position[self.row_fn(r)]
            j = col_position[self.col_fn(r)]
            ordered_results[i][j] = r

        unique_envs = {r.env for r in self.results}
        render_env = len(unique_envs) > 1
        env_str_len = max(len(i) for i in unique_envs) if render_env else 0

        row_name_str_len = max(len(r.as_row_name) for r in self.results)

        prior_num_threads = -1
        prior_env = ""
        row_group = -1
        rows_by_group: List[List[List[Optional[common.Measurement]]]] = []
        for (num_threads, env, _), row in zip(self.row_keys, ordered_results):
            thread_transition = (num_threads != prior_num_threads)
            if thread_transition:
                prior_num_threads = num_threads
                prior_env = ""
                row_group += 1
                rows_by_group.append([])
            rows.append(
                _Row(
                    results=row,
                    row_group=row_group,
                    render_env=(render_env and env != prior_env),
                    env_str_len=env_str_len,
                    row_name_str_len=row_name_str_len,
                    time_scale=self.time_scale,
                    colorize=self._colorize,
                    num_threads=num_threads if thread_transition else None,
                )
            )
            rows_by_group[-1].append(row)
            prior_env = env

        for i in range(len(self.column_keys)):
            grouped_results = cast(
                List[Tuple[common.Measurement, ...]],  # All Nones should be gone.
                [tuple(row[i] for row in g) for g in rows_by_group],
            )
            column = _Column(
                grouped_results=grouped_results,
                time_scale=self.time_scale,
                time_unit=self.time_unit,
                trim_significant_figures=self._trim_significant_figures,
                highlight_warnings=self._highlight_warnings,)
            columns.append(column)

        rows_tuple, columns_tuple = tuple(rows), tuple(columns)
        for ri in rows_tuple:
            ri.register_columns(columns_tuple)
        return rows_tuple, columns_tuple

    def render(self):
        string_rows = [[""] + self.column_keys]
        for r in self.rows:
            string_rows.append(r.as_column_strings())
        num_cols = max(len(i) for i in string_rows)
        for r in string_rows:
            r.extend(["" for _ in range(num_cols - len(r))])

        col_widths = [max(len(j) for j in i) for i in zip(*string_rows)]
        finalized_columns = ["  |  ".join(i.center(w) for i, w in zip(string_rows[0], col_widths))]
        overall_width = len(finalized_columns[0])
        for string_row, row in zip(string_rows[1:], self.rows):
            finalized_columns.extend(row.row_separator(overall_width))
            finalized_columns.append("  |  ".join(row.finalize_column_strings(string_row, col_widths)))
        print("[" + (" " + (self.label or "") + " ").center(overall_width - 2, "-") + "]")
        print("\n".join(finalized_columns))
        print(f"\nTimes are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).")
        if self._highlight_warnings and any(r.has_warnings for r in self.results):
            print("(! XX%) Measurement has high variance, where XX is the median / IQR * 100.")
        print("\n")


class Compare(object):
    """Collects measurements and prints them as one table per label.

    Display options (significant-figure trimming, colorization, warning
    highlighting) are off by default and are switched on by the
    correspondingly named methods.
    """

    def __init__(self, results: List[common.Measurement]):
        self._results: List[common.Measurement] = []
        self.extend_results(results)
        self._trim_significant_figures = False
        self._colorize = False
        self._highlight_warnings = False

    def extend_results(self, results):
        """Validate `results` and append them to the stored measurements.

        Args:
            results: Iterable of `Measurement` objects.

        Raises:
            ValueError: If any element is not a `Measurement`. Nothing is
                stored in that case.
        """
        # Materialize once: validation iterates the input, so a one-shot
        # iterator would otherwise be empty by the time it is stored.
        results = list(results)
        for r in results:
            if not isinstance(r, common.Measurement):
                raise ValueError(
                    "Expected an instance of `Measurement`, " f"got {type(r)} instead."
                )
        self._results.extend(results)

    def trim_significant_figures(self):
        """Enable trimming of printed values to their significant figures."""
        self._trim_significant_figures = True

    def colorize(self):
        """Enable ANSI coloring of table cells."""
        self._colorize = True

    def highlight_warnings(self):
        """Enable annotation of measurements that carry warnings."""
        self._highlight_warnings = True

    def print(self):
        """Print one table per label to stdout."""
        self._render()

    def _render(self):
        """Merge the stored measurements, group them by label and lay out
        each group as its own table."""
        merged = common.merge_measurements(self._results)
        for group in self._group_by_label(merged).values():
            self._layout(group)

    def _group_by_label(self, results):
        """Return a defaultdict mapping each label to its measurements, in
        first-seen order."""
        grouped_results = collections.defaultdict(list)
        for r in results:
            grouped_results[r.label].append(r)
        return grouped_results

    def _layout(self, results: List[common.Measurement]):
        """Build and print the table for one group of same-label results."""
        table = Table(results, self._colorize, self._trim_significant_figures,
                      self._highlight_warnings)
        table.render()
